{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 交叉熵实验\n",
    "交叉熵是度量损失的一种计算公式\n",
    "下面一段代码，假设有一个标签labels和一个网络输出值logits。\n",
    "这个实例就是以这两个值进行以下3次实验\n",
    "（1）两次softmax实验：将输出值logits分别进行1次和2次softmax，观察两次的区别及意义\n",
    "（2）观察交叉熵：将步骤（1）中的两个值分别进行softmax_cross_entropy_with_logits，观察他们的区别\n",
    "（3）自建公式实验：将做两次softmax的值放到自建组合的公式里得到正确的值"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-2-56e2b6e2a134>:11: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "\n",
      "Future major versions of TensorFlow will allow gradients to flow\n",
      "into the labels input on backprop by default.\n",
      "\n",
      "See @{tf.nn.softmax_cross_entropy_with_logits_v2}.\n",
      "\n",
      "softmax计算结果=\n",
      " [[0.01791432 0.00399722 0.97808844]\n",
      " [0.04980332 0.04506391 0.90513283]]\n",
      "tensorflow交叉熵损失计算结果=\n",
      " [0.02215516 3.0996735 ]\n",
      "自建交叉熵损失计算结果=\n",
      " [0.02215518 3.0996735 ]\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "# 实际测试集的标签one-hot编码\n",
    "labels = [[0, 0, 1],\n",
    "          [0, 1, 0]]\n",
    "# 预测得到的结果\n",
    "logits = [[2, 0.5, 6],\n",
    "          [0.1, 0, 3]]\n",
    "# softmax的计算结果\n",
    "logits_scaled = tf.nn.softmax(logits)\n",
    "# 交叉熵损失函数的计算结果，注意：tensorflow中计算交叉熵损失值时，自带了softmax的计算\n",
    "result1 = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)\n",
    "# 交叉熵自建公式计算结果\n",
    "result3 = -tf.reduce_sum(labels*tf.log(logits_scaled), 1)\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    print (\"softmax计算结果=\\n\", sess.run(logits_scaled))\n",
    "    print (\"tensorflow交叉熵损失计算结果=\\n\", sess.run(result1))\n",
    "    print (\"自建交叉熵损失计算结果=\\n\", sess.run(result3))"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [conda root]",
   "language": "python",
   "name": "conda-root-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
